"""
Phase 4b: Merge expanded bond yields into bilateral panel and re-estimate.

Uses expanded_bond_yields.csv (35 countries) already downloaded.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys

# Hard-coded absolute project locations; edit these when running the script
# from a different machine or mount point.
BASE_DIR = Path("/mnt/c/demographics_capital_flows/gravity_bilateral")
MAIN_DIR = Path("/mnt/c/demographics_capital_flows/multilateral")
PROCESSED_DIR = BASE_DIR / "data" / "processed"  # bilateral_panel.csv is read from here
OUTPUT_DIR = BASE_DIR / "output" / "tables"      # result CSVs are written here
RAW_DIR = BASE_DIR / "data" / "raw"              # expanded_bond_yields.csv is read from here

# Put the project root first on the import path so the `src` package resolves.
# Reusing BASE_DIR (instead of repeating the literal) keeps the import location
# and the data locations from drifting apart.
sys.path.insert(0, str(BASE_DIR))
from src.model import PanelGLS  # noqa: E402  (must follow the sys.path tweak)


def _significance_stars(p_value):
    """Map a p-value to the star marker used in the printed tables.

    Returns '***' below 0.01, '**' below 0.05, '*' below 0.1 and '' otherwise
    (including NaN, for which every comparison is false).
    """
    for threshold, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p_value < threshold:
            return marker
    return ''


def estimate_model(df, dep_var, regressors, model_name, year_dummies=True):
    """Estimate one gravity specification with PanelGLS and tabulate the result.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``pair_id``, ``year``, ``dep_var`` and every name in
        ``regressors``. Pre-built ``yr_<year>`` dummy columns are used when
        present; they are not created here.
    dep_var : str
        Name of the dependent-variable column.
    regressors : sequence of str
        Columns of substantive regressors, reported in this order.
    model_name : str
        Label used in the console output and in the ``model`` column.
    year_dummies : bool, default True
        Whether to append the available ``yr_<year>`` columns to the design
        matrix (the earliest year in ``df`` is skipped as the reference).

    Returns
    -------
    tuple
        ``(gls, results)`` where ``gls`` is the fitted ``PanelGLS`` object and
        ``results`` is a DataFrame with one row per regressor plus the summary
        rows ``_R_squared``, ``_N_obs``, ``_N_pairs`` and ``_rho``.
        ``(None, None)`` when fewer than 100 complete observations remain.
    """
    regressors = list(regressors)  # accept any sequence, not only lists

    yr_cols = []
    if year_dummies:
        years = sorted(df['year'].dropna().unique())
        # Skip the first year (reference category) and keep only dummies that
        # the caller has actually built.
        yr_cols = [f'yr_{int(y)}' for y in years[1:]]
        yr_cols = [c for c in yr_cols if c in df.columns]

    needed = [dep_var] + regressors + yr_cols + ['pair_id', 'year']
    est = df.dropna(subset=needed).copy()
    if len(est) < 100:
        print(f"  {model_name}: insufficient obs ({len(est)})")
        return None, None

    # Dropping incomplete rows can remove whole years from the sample. A dummy
    # for such a year is identically zero and only makes the design matrix
    # rank-deficient, so leave it out of the regression.
    yr_cols = [c for c in yr_cols if est[c].any()]
    all_vars = regressors + yr_cols

    print(f"\n{'=' * 70}")
    print(f"  {model_name}")
    print(f"  N = {len(est):,}, Pairs = {est['pair_id'].nunique():,}")
    print(f"  Years: {est['year'].min():.0f}-{est['year'].max():.0f}")
    print(f"{'=' * 70}")

    y = est[dep_var].values
    X = est[all_vars].values

    gls = PanelGLS()
    gls.fit(y, X, est['pair_id'].values, est['year'].values.astype(int))

    print(f"  R² = {gls.r_squared:.4f}, ρ = {gls.rho:.3f}")
    print(f"  {'Variable':<35} {'Coef':>10} {'SE':>10} {'p-val':>8}")
    print(f"  {'-' * 65}")

    # The regressors occupy the leading columns of X, so position i in the
    # fitted arrays is read as regressor i; year-dummy coefficients are not
    # reported. NOTE(review): this assumes PanelGLS returns coefficients in the
    # column order of X without prepending anything -- confirm in src.model.
    results = []
    for i, v in enumerate(regressors):
        sig = _significance_stars(gls.pvalues[i])
        print(f"  {v:<35} {gls.beta[i]:>10.4f} {gls.se[i]:>10.4f} {gls.pvalues[i]:>8.4f} {sig}")
        results.append({
            'model': model_name, 'variable': v,
            'coefficient': gls.beta[i], 'std_error': gls.se[i],
            't_stat': gls.tvalues[i], 'p_value': gls.pvalues[i],
        })

    # Model-level statistics share the table; they are stored in the
    # 'coefficient' column under underscore-prefixed pseudo-variable names.
    for meta_var, meta_val in [('_R_squared', gls.r_squared), ('_N_obs', gls.n_obs),
                               ('_N_pairs', gls.n_countries), ('_rho', gls.rho)]:
        results.append({'model': model_name, 'variable': meta_var,
                        'coefficient': meta_val, 'std_error': np.nan,
                        't_stat': np.nan, 'p_value': np.nan})

    return gls, pd.DataFrame(results)


# Column names shared by the helpers below.
_S1_VARS = ['Z_1', 'Z_2', 'Z_3']
_GRAVITY_VARS = ['log_dist', 'contiguity', 'common_lang_official', 'colonial_ties', 'log_gdp_product']
_DEMO_VARS = ['dZ_1', 'dZ_2', 'dZ_3']


def _banner(title):
    """Print a blank line followed by a section title framed by 70-character rules."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _load_real_yield_differentials():
    """Load yields and the macro panel; return ``(yields_real, fp)``.

    ``yields_real`` holds, per country-year, the real 10-year yield
    (nominal yield minus inflation) and its gap to the GDP-weighted
    cross-country average for that year (``real_bond_10y_diff``).
    ``fp`` is the macro panel restricted to years up to 2024.
    """
    yields = pd.read_csv(RAW_DIR / "expanded_bond_yields.csv")
    print(f"\nExpanded yields: {yields['iso3'].nunique()} countries, {len(yields)} obs")

    fp = pd.read_csv(MAIN_DIR / "followup" / "data" / "processed" / "full_panel.csv",
                     usecols=['iso3', 'year', 'inflation', 'ngdp_usd'] + _S1_VARS)
    fp = fp[fp['year'] <= 2024]

    yields_real = yields.merge(fp[['iso3', 'year', 'inflation', 'ngdp_usd']],
                               on=['iso3', 'year'], how='inner')
    yields_real['real_bond_10y'] = yields_real['govt_bond_10y'] - yields_real['inflation']

    # GDP-weighted average per year. Weights are floored at 0.1, presumably so
    # that zero or negative GDP entries cannot break np.average -- confirm.
    world_avg = (yields_real.dropna(subset=['real_bond_10y', 'ngdp_usd'])
                 .groupby('year')
                 .apply(lambda x: np.average(x['real_bond_10y'],
                                             weights=x['ngdp_usd'].clip(lower=0.1)),
                        include_groups=False)
                 .rename('real_bond_10y_world'))
    yields_real = yields_real.merge(world_avg.reset_index(), on='year', how='left')
    yields_real['real_bond_10y_diff'] = (yields_real['real_bond_10y']
                                         - yields_real['real_bond_10y_world'])
    return yields_real, fp


def _fit_s1(sample, heading):
    """Fit the S1 regression of the real-yield gap on Z_1..Z_3, print it, return the model."""
    print(f"\n  --- {heading} ---")
    gls = PanelGLS()
    gls.fit(sample['real_bond_10y_diff'].values,
            sample[_S1_VARS].values,
            sample['iso3'].values, sample['year'].values.astype(int))
    print(f"  R² = {gls.r_squared:.4f}, N = {gls.n_obs}, Countries = {gls.n_countries}")
    for i, v in enumerate(_S1_VARS):
        p = gls.pvalues[i]
        sig = '***' if p < 0.01 else '**' if p < 0.05 else '*' if p < 0.1 else ''
        print(f"    {v}: {gls.beta[i]:.4f} (SE={gls.se[i]:.4f}, p={p:.4f}) {sig}")
    return gls


def _s1_rows(label, gls):
    """Build the per-coefficient rows of the S1 comparison table for one fitted model."""
    return [
        {'model': label, 'variable': v, 'coefficient': gls.beta[i],
         'std_error': gls.se[i], 'p_value': gls.pvalues[i],
         'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
         'r_squared': gls.r_squared}
        for i, v in enumerate(_S1_VARS)
    ]


def _merge_expanded_yields(df, yields_real):
    """Add year dummies and the reporter-minus-partner expanded yield gap to the bilateral panel."""
    # Year dummies; the first year is the omitted reference category.
    years = sorted(df['year'].dropna().unique())
    for y in years[1:]:
        df[f'yr_{int(y)}'] = (df['year'] == y).astype(int)

    # Attach the yield gap once for the reporter (i) and once for the partner (j).
    lookup = yields_real[['iso3', 'year', 'real_bond_10y_diff']]
    for side, col in (('reporter', 'rbond_exp_i'), ('partner', 'rbond_exp_j')):
        df = df.merge(lookup.rename(columns={'iso3': side, 'real_bond_10y_diff': col}),
                      on=[side, 'year'], how='left')

    df['expanded_rate_diff_ij'] = df['rbond_exp_i'] - df['rbond_exp_j']
    return df


def _print_coverage(df):
    """Print how many observations and pairs have the original vs expanded rate differential."""
    has_old = df['rate_diff_ij'].notna()
    has_new = df['expanded_rate_diff_ij'].notna()
    portfolio = df['log_portfolio_total'].notna()

    old_rate = has_old.sum()
    new_rate = has_new.sum()
    old_pairs = df[has_old]['pair_id'].nunique()
    new_pairs = df[has_new]['pair_id'].nunique()

    old_p = (portfolio & has_old).sum()
    new_p = (portfolio & has_new).sum()
    old_pp = df[portfolio & has_old]['pair_id'].nunique()
    new_pp = df[portfolio & has_new]['pair_id'].nunique()

    print(f"\n  COVERAGE COMPARISON:")
    print(f"  {'':30} {'Original':>12} {'Expanded':>12} {'Improvement':>12}")
    print(f"  {'-' * 68}")
    print(f"  {'All obs with rate diff':<30} {old_rate:>12,} {new_rate:>12,} {new_rate-old_rate:>+12,}")
    print(f"  {'Pairs with rate diff':<30} {old_pairs:>12,} {new_pairs:>12,} {new_pairs-old_pairs:>+12,}")
    print(f"  {'Portfolio obs with rate diff':<30} {old_p:>12,} {new_p:>12,} {new_p-old_p:>+12,}")
    print(f"  {'Portfolio pairs with rate diff':<30} {old_pp:>12,} {new_pp:>12,} {new_pp-old_pp:>+12,}")


def _estimate_gravity_models(df, s1_gls_exp):
    """Estimate the six gravity specifications; return ``(fits, all_results)``.

    ``fits`` maps a short key to the ``(gls, results)`` pair returned by
    ``estimate_model``; ``all_results`` lists the non-empty result tables in
    estimation order.
    """
    dep_var = 'log_portfolio_total'

    # Fitted rate differential from fixed S1 coefficients.
    # NOTE(review): these look like the coefficients of an earlier 23-country
    # S1 run; they are not taken from the S1 model re-fitted in this script.
    ORIG_S1 = {'Z_1': 16.3197, 'Z_2': -2.0746, 'Z_3': 0.0718}
    df['fitted_rate_orig'] = (ORIG_S1['Z_1'] * df['dZ_1'] +
                              ORIG_S1['Z_2'] * df['dZ_2'] +
                              ORIG_S1['Z_3'] * df['dZ_3'])
    # Fitted rate differential from the expanded-sample S1 estimated above.
    df['fitted_rate_exp'] = (s1_gls_exp.beta[0] * df['dZ_1'] +
                             s1_gls_exp.beta[1] * df['dZ_2'] +
                             s1_gls_exp.beta[2] * df['dZ_3'])

    specs = [
        ('base', _GRAVITY_VARS, "2a: Baseline Gravity"),
        ('demo', _GRAVITY_VARS + _DEMO_VARS, "2b: Full Demographics"),
        ('2e_orig', _GRAVITY_VARS + _DEMO_VARS + ['rate_diff_ij'],
         "2e-orig: Demographics + Original Rate Diff (23 ctry)"),
        ('2e_exp', _GRAVITY_VARS + _DEMO_VARS + ['expanded_rate_diff_ij'],
         "2e-exp: Demographics + Expanded Rate Diff (35 ctry)"),
        ('2f_orig', _GRAVITY_VARS + ['fitted_rate_orig'],
         "2f-orig: Fitted Rate Diff (S1 from 23 OECD)"),
        ('2f_exp', _GRAVITY_VARS + ['fitted_rate_exp'],
         "2f-exp: Fitted Rate Diff (S1 from 35 ctry)"),
    ]

    fits = {}
    all_results = []
    for key, regs, name in specs:
        gls, res = estimate_model(df, dep_var, regs, name)
        fits[key] = (gls, res)
        if res is not None:
            all_results.append(res)
    return fits, all_results


def _mediation_decomposition(gls_base, gls_demo, fitted_models):
    """Split the demographic R² gain into rate-mediated and direct shares; print and save.

    Nothing is printed or written when either reference model is missing.
    ``fitted_models`` is a list of ``(label, gls_or_None)`` pairs.
    """
    if gls_base is None or gls_demo is None:
        return

    r2_base = gls_base.r_squared
    r2_demo = gls_demo.r_squared
    total = r2_demo - r2_base

    decomp_rows = [
        {'metric': 'R2_baseline', 'value': r2_base},
        {'metric': 'R2_full_demographics', 'value': r2_demo},
        {'metric': 'total_demo_R2_improvement', 'value': total},
    ]

    for label, gls_fitted in fitted_models:
        if gls_fitted is None:
            continue
        r2_rate = gls_fitted.r_squared
        rate = r2_rate - r2_base
        direct = total - rate
        # Shares are undefined when demographics add no explanatory power.
        rate_pct = rate / total * 100 if total > 0 else np.nan
        direct_pct = direct / total * 100 if total > 0 else np.nan

        print(f"\n  {label}:")
        # Signed format so a negative change prints as "-0.0123", not "+-0.0123".
        print(f"    Fitted Δr̂ R²:           {r2_rate:.4f} ({rate:+.4f})")
        print(f"    Rate-mediated share:     {rate_pct:.1f}%")
        print(f"    Direct channel share:    {direct_pct:.1f}%")

        decomp_rows.extend([
            {'metric': f'R2_fitted_rate_{label}', 'value': r2_rate},
            {'metric': f'rate_channel_{label}_pct', 'value': rate_pct},
            {'metric': f'direct_channel_{label}_pct', 'value': direct_pct},
        ])

    pd.DataFrame(decomp_rows).to_csv(OUTPUT_DIR / "mediation_decomposition_expanded.csv", index=False)


def _print_key_comparisons(s1_gls_orig, s1_gls_exp, fit_2e_orig, fit_2e_exp):
    """Print the side-by-side S1 table and, when both 2e models exist, the 2e table."""
    print(f"\n  A. Does expanding the S1 sample strengthen the first stage?")
    print(f"  {'':20} {'S1 (23 OECD)':>20} {'S1 (35 countries)':>20}")
    print(f"  {'-' * 62}")
    for i, v in enumerate(_S1_VARS):
        p_orig = s1_gls_orig.pvalues[i]
        p_exp = s1_gls_exp.pvalues[i]
        print(f"  {v:<20} {s1_gls_orig.beta[i]:>8.3f} (p={p_orig:.3f}) {s1_gls_exp.beta[i]:>8.3f} (p={p_exp:.3f})")
    print(f"  {'R²':<20} {s1_gls_orig.r_squared:>20.4f} {s1_gls_exp.r_squared:>20.4f}")
    print(f"  {'N obs':<20} {s1_gls_orig.n_obs:>20} {s1_gls_exp.n_obs:>20}")
    print(f"  {'N countries':<20} {s1_gls_orig.n_countries:>20} {s1_gls_exp.n_countries:>20}")

    gls_2e_orig, res_2e_orig = fit_2e_orig
    gls_2e_exp, res_2e_exp = fit_2e_exp
    if gls_2e_orig is None or gls_2e_exp is None:
        return

    print(f"\n  B. Does expanding yield coverage change Model 2e?")
    print(f"  {'':25} {'2e (23 ctry)':>15} {'2e (35 ctry)':>15}")
    print(f"  {'-' * 57}")
    for v in _DEMO_VARS:
        orig_match = res_2e_orig[res_2e_orig['variable'] == v]
        exp_match = res_2e_exp[res_2e_exp['variable'] == v]
        if len(orig_match) > 0 and len(exp_match) > 0:
            orig_row = orig_match.iloc[0]
            exp_row = exp_match.iloc[0]
            print(f"  {v:<25} {orig_row['coefficient']:>7.3f} (p={orig_row['p_value']:.3f}) "
                  f"{exp_row['coefficient']:>7.3f} (p={exp_row['p_value']:.3f})")

    # The rate differential itself: each variable appears in only one model.
    for v in ['rate_diff_ij', 'expanded_rate_diff_ij']:
        for label, res in [('2e (23 ctry)', res_2e_orig), ('2e (35 ctry)', res_2e_exp)]:
            row = res[res['variable'] == v]
            if len(row) > 0:
                r = row.iloc[0]
                print(f"  {v:<25} {r['coefficient']:>7.4f} (p={r['p_value']:.3f}) in {label}")

    # R² and sample sizes, stored as pseudo-variables in the result tables.
    for meta in ['_R_squared', '_N_obs', '_N_pairs']:
        orig_val = res_2e_orig.loc[res_2e_orig['variable'] == meta, 'coefficient'].values
        exp_val = res_2e_exp.loc[res_2e_exp['variable'] == meta, 'coefficient'].values
        if len(orig_val) > 0 and len(exp_val) > 0:
            fmt = '.4f' if meta == '_R_squared' else ',.0f'
            print(f"  {meta:<25} {orig_val[0]:{fmt}} {'':>10} {exp_val[0]:{fmt}}")


def main():
    """Run Phase 4b: rebuild real-yield gaps, re-fit S1, and re-estimate the gravity models.

    Steps:
      1. Fit the S1 regression on the original 23-country list and on every
         country with data, and save both coefficient sets.
      2. Merge the expanded yield gaps into the bilateral panel and report
         coverage against the existing ``rate_diff_ij`` column.
      3. Re-estimate the gravity specifications via ``estimate_model``.
      4. Decompose the demographic R² gain into rate-mediated and direct parts.
      5. Print side-by-side comparison tables and save all gravity results.

    Writes ``s1_expanded_coefficients.csv``,
    ``mediation_decomposition_expanded.csv`` and
    ``gravity_results_expanded.csv`` into ``OUTPUT_DIR``.
    """
    print("=" * 70)
    print("PHASE 4B: EXPANDED YIELD RE-ESTIMATION")
    print("=" * 70)

    yields_real, fp = _load_real_yield_differentials()

    # ------------------------------------------------------------------
    # STEP 1: Re-estimate S1 on expanded country set
    # ------------------------------------------------------------------
    _banner("STEP 1: S1 ESTIMATION ON EXPANDED COUNTRY SET")

    s1_data = yields_real.merge(fp[['iso3', 'year'] + _S1_VARS].dropna(),
                                on=['iso3', 'year'], how='inner')
    s1_data = s1_data.dropna(subset=['real_bond_10y_diff'] + _S1_VARS)

    print(f"\n  S1 sample: {len(s1_data)} obs, {s1_data['iso3'].nunique()} countries")
    print(f"  Countries: {sorted(s1_data['iso3'].unique())}")

    # Country list of the original S1 sample.
    orig_countries = ['AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'DEU', 'DNK', 'ESP', 'FIN',
                      'FRA', 'GBR', 'GRC', 'IRL', 'ITA', 'JPN', 'KOR', 'MEX', 'NLD',
                      'NOR', 'NZL', 'PRT', 'SWE', 'USA']
    s1_orig = s1_data[s1_data['iso3'].isin(orig_countries)]

    s1_gls_orig = _fit_s1(s1_orig, "Original S1 (23 OECD)")
    s1_gls_exp = _fit_s1(s1_data, "Expanded S1 (35 countries)")

    # Make sure the output folder exists before the first table is written.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    s1_comparison = pd.DataFrame(_s1_rows('S1_original_23', s1_gls_orig)
                                 + _s1_rows('S1_expanded_35', s1_gls_exp))
    s1_comparison.to_csv(OUTPUT_DIR / "s1_expanded_coefficients.csv", index=False)

    # ------------------------------------------------------------------
    # STEP 2: Merge expanded yields into bilateral panel
    # ------------------------------------------------------------------
    _banner("STEP 2: MERGE INTO BILATERAL PANEL")

    df = pd.read_csv(PROCESSED_DIR / "bilateral_panel.csv")
    print(f"\n  Bilateral panel: {len(df):,} obs, {df['pair_id'].nunique():,} pairs")

    df = _merge_expanded_yields(df, yields_real)
    _print_coverage(df)

    # ------------------------------------------------------------------
    # STEP 3: Re-estimate gravity models
    # ------------------------------------------------------------------
    _banner("STEP 3: RE-ESTIMATE GRAVITY MODELS")
    fits, all_results = _estimate_gravity_models(df, s1_gls_exp)

    # ------------------------------------------------------------------
    # STEP 4: Mediation decompositions
    # ------------------------------------------------------------------
    _banner("STEP 4: MEDIATION DECOMPOSITIONS")
    _mediation_decomposition(
        fits['base'][0], fits['demo'][0],
        [('S1_original_23', fits['2f_orig'][0]),
         ('S1_expanded_35', fits['2f_exp'][0])])

    # ------------------------------------------------------------------
    # STEP 5: Key comparison table
    # ------------------------------------------------------------------
    _banner("STEP 5: KEY COMPARISONS")
    _print_key_comparisons(s1_gls_orig, s1_gls_exp, fits['2e_orig'], fits['2e_exp'])

    # Save all results
    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        results_df.to_csv(OUTPUT_DIR / "gravity_results_expanded.csv", index=False)
        print(f"\n  Saved: {OUTPUT_DIR / 'gravity_results_expanded.csv'}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
